/*
 * Copyright (c) 2018, apexes.net. All rights reserved.
 *
 *         http://www.apexes.net
 *
 */
package net.apexes.commons.ormlite;

import com.j256.ormlite.dao.BaseDaoImpl;
import com.j256.ormlite.dao.Dao;
import com.j256.ormlite.dao.DaoManager;
import com.j256.ormlite.db.DatabaseType;
import com.j256.ormlite.field.FieldType;
import com.j256.ormlite.logger.Logger;
import com.j256.ormlite.logger.LoggerFactory;
import com.j256.ormlite.misc.SqlExceptionUtil;
import com.j256.ormlite.stmt.StatementBuilder.StatementType;
import com.j256.ormlite.support.CompiledStatement;
import com.j256.ormlite.support.ConnectionSource;
import com.j256.ormlite.support.DatabaseConnection;
import com.j256.ormlite.table.TableInfo;
import com.j256.ormlite.table.TableUtils;
import net.apexes.commons.lang.Checks;

import java.lang.reflect.Field;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * Brings the database table behind a {@link Table} definition up to date.
 * <p>
 * If the {@link UpgradeChecker} reports that the table does not exist, it is created via
 * {@link TableUtils#createTable}. Otherwise each {@link Column} field declared directly on the table
 * class that the checker reports as missing is added with {@code ALTER TABLE ... ADD COLUMN}.
 * Existing columns are never modified or dropped.
 *
 * @author <a href=mailto:hedyn@foxmail.com>HeDYn</a>
 */
public class UpgradeHelper {

    private static final Logger LOG = LoggerFactory.getLogger(UpgradeHelper.class);
    private static final FieldType[] NO_FIELD_TYPES = new FieldType[0];

    private final ConnectionSource connectionSource;
    private final UpgradeChecker checker;

    /**
     * @param connectionSource source of the connection used to run the DDL; verified non-null
     * @param checker          decides whether a table or column already exists; verified non-null
     */
    public UpgradeHelper(ConnectionSource connectionSource, UpgradeChecker checker) {
        Checks.verifyNotNull(connectionSource, "connectionSource");
        Checks.verifyNotNull(checker, "checker");
        this.connectionSource = connectionSource;
        this.checker = checker;
    }

    /**
     * Creates the table if it is missing, otherwise appends any missing columns.
     *
     * @param table the table definition to upgrade
     * @param <T>   entity type of the table
     * @param <ID>  id type used for the DAO lookup
     * @return the result of {@link TableUtils#createTable} when the table was created, otherwise the
     *         number of DDL statements executed ({@code 0} when nothing was missing)
     * @throws Exception if reflection on the table class fails or a DDL statement fails
     */
    public <T, ID> int upgrade(Table<T> table) throws Exception {
        if (!checker.exists(connectionSource, table)) {
            return TableUtils.createTable(connectionSource, table.config());
        }

        Set<String> newFieldNames = findMissingColumnFieldNames(table);
        if (Checks.isEmpty(newFieldNames)) {
            return 0;
        }

        Dao<T, ID> dao = DaoManager.createDao(connectionSource, table.config());
        DatabaseType databaseType = connectionSource.getDatabaseType();
        if (dao instanceof BaseDaoImpl<?, ?>) {
            return doAlterTableAppendColumn(databaseType, ((BaseDaoImpl<?, ?>) dao).getTableInfo(), newFieldNames);
        } else {
            TableInfo<T, ID> tableInfo = new TableInfo<>(databaseType, null, table.config());
            return doAlterTableAppendColumn(databaseType, tableInfo, newFieldNames);
        }
    }

    /**
     * Returns the Java field names of the {@link Column} fields declared on the table class whose
     * columns the checker reports as not existing.
     */
    private Set<String> findMissingColumnFieldNames(Table<?> table) throws Exception {
        Set<String> names = new HashSet<>();
        for (Field field : table.getClass().getDeclaredFields()) {
            if (!Column.class.isAssignableFrom(field.getType())) {
                continue;
            }
            // Column fields may be declared non-public; without this Field.get throws IllegalAccessException.
            field.setAccessible(true);
            Column column = (Column) field.get(table);
            if (!checker.exists(connectionSource, column, field)) {
                names.add(field.getName());
            }
        }
        return names;
    }

    /**
     * Issues one {@code ALTER TABLE ... ADD COLUMN} per entity field whose Java name is in
     * {@code newFieldNames}, plus any supporting statements the database type asks for.
     *
     * @return the number of statements executed
     */
    private <T, ID> int doAlterTableAppendColumn(DatabaseType databaseType,
                                                 TableInfo<T, ID> tableInfo,
                                                 Set<String> newFieldNames) throws Exception {
        String tableName = tableInfo.getTableName();
        List<String> alterStatements = new ArrayList<>();
        List<String> queriesAfter = new ArrayList<>();
        List<String> additionalArgs = new ArrayList<>();
        List<String> statementsBefore = new ArrayList<>();
        List<String> statementsAfter = new ArrayList<>();

        StringBuilder sb = new StringBuilder(256);
        sb.append("ALTER TABLE ");
        databaseType.appendEscapedEntityName(sb, tableName);
        sb.append(" ADD COLUMN ");
        String alterCommand = sb.toString();

        for (FieldType fieldType : tableInfo.getFieldTypes()) {
            // foreign collections have no column of their own
            if (fieldType.isForeignCollection()) {
                continue;
            }
            if (!newFieldNames.contains(fieldType.getField().getName())) {
                continue;
            }
            sb.setLength(0);
            sb.append(alterCommand);
            String columnDefinition = fieldType.getColumnDefinition();
            if (columnDefinition == null) {
                // let the database type produce the column syntax (and any supporting statements)
                databaseType.appendColumnArg(tableName, sb, fieldType, additionalArgs,
                        statementsBefore, statementsAfter, queriesAfter);
            } else {
                // hand defined column
                databaseType.appendEscapedEntityName(sb, fieldType.getColumnName());
                sb.append(' ').append(columnDefinition).append(' ');
            }
            alterStatements.add(sb.toString());
        }

        if (alterStatements.isEmpty()) {
            // the table class declares columns that no entity field maps to; nothing can be altered
            LOG.warn("table {}: no entity field matches missing columns {}", tableName, newFieldNames);
            return 0;
        }
        if (!additionalArgs.isEmpty()) {
            // e.g. table-level clauses appended after the column list in CREATE TABLE; they have no
            // place in a bare ADD COLUMN statement, so surface them instead of dropping them silently
            LOG.warn("table {}: additional column arguments were not applied: {}", tableName, additionalArgs);
        }
        if (!queriesAfter.isEmpty()) {
            LOG.debug("table {}: skipped follow-up queries: {}", tableName, queriesAfter);
        }

        // Same ordering TableUtils uses for CREATE TABLE: prerequisites, the DDL itself, follow-ups.
        List<String> statements =
                new ArrayList<>(statementsBefore.size() + alterStatements.size() + statementsAfter.size());
        statements.addAll(statementsBefore);
        statements.addAll(alterStatements);
        statements.addAll(statementsAfter);

        DatabaseConnection connection = connectionSource.getReadWriteConnection(tableName);
        try {
            return doStatements(connection, "alter table add column", statements, false,
                    databaseType.isCreateTableReturnsNegative(),
                    databaseType.isCreateTableReturnsZero());
        } finally {
            connectionSource.releaseConnection(connection);
        }
    }

    /**
     * Executes each statement in order on {@code connection}.
     *
     * @param label           short description used in log and error messages
     * @param ignoreErrors    if true, an {@link SQLException} from a statement is logged and skipped
     * @param returnsNegative if false, a negative row count is treated as an error
     * @param expectingZero   if true, a positive row count is treated as an error
     * @return the number of statements processed (including ignored failures)
     * @throws SQLException if a statement fails (and errors are not ignored) or a row count check fails
     */
    private static int doStatements(DatabaseConnection connection,
                                    String label,
                                    Collection<String> statements,
                                    boolean ignoreErrors,
                                    boolean returnsNegative,
                                    boolean expectingZero) throws Exception {
        int stmtC = 0;
        for (String statement : statements) {
            int rowC = 0;
            CompiledStatement compiledStmt = null;
            try {
                compiledStmt =
                        connection.compileStatement(statement, StatementType.EXECUTE, NO_FIELD_TYPES,
                                DatabaseConnection.DEFAULT_RESULT_FLAGS, false);
                rowC = compiledStmt.runExecute();
                LOG.info(statement);
                LOG.debug("executed {} table statement changed {} rows: {}", label, rowC, statement);
            } catch (SQLException e) {
                if (ignoreErrors) {
                    LOG.info("ignoring {} error '{}' for statement: {}", label, e, statement);
                } else {
                    throw SqlExceptionUtil.create("SQL statement failed: " + statement, e);
                }
            } finally {
                if (compiledStmt != null) {
                    compiledStmt.close();
                }
            }
            // sanity check on the reported row count
            if (rowC < 0) {
                if (!returnsNegative) {
                    throw new SQLException("SQL statement " + statement + " updated " + rowC
                            + " rows, we were expecting >= 0");
                }
            } else if (rowC > 0 && expectingZero) {
                throw new SQLException("SQL statement updated " + rowC
                        + " rows, we were expecting == 0: " + statement);
            }
            stmtC++;
        }
        return stmtC;
    }
}
